from typing import Optional
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
from torch import Tensor
from torch.nn import functional as F
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from src.models.denoiser.attention_fusion import CrossAttentionFusion
from einops import rearrange
from src.models.denoiser.mlp import GatedMLP
from src.models.denoiser.scanning_orders import SCAN_ZOO, local_reverse, local_scan, reverse_permut_np
from src.models.denoiser.wavelet_layer import DWT_2D, IDWT_2D
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation to ``x``.

    ``shift`` and ``scale`` each receive a singleton axis at position 1 so
    that they broadcast along that axis of ``x``. The result is
    ``x * (1 + scale) + shift``.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: zero out whole samples of ``x`` at random.

    The input is returned untouched when ``training`` is false or
    ``drop_prob`` is exactly 0. Otherwise one Bernoulli draw with success
    probability ``1 - drop_prob`` is made per entry of the leading axis and
    broadcast over all remaining axes. With ``scale_by_keep`` the surviving
    samples are divided by the keep probability (skipped when that
    probability is 0).
    """
    if not training or drop_prob == 0.0:
        return x

    survive_prob = 1 - drop_prob
    # One mask entry per sample; trailing singleton axes make it broadcast
    # against inputs of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survive_prob)
    if scale_by_keep and survive_prob > 0.0:
        mask.div_(survive_prob)
    return x * mask
class DropPath(nn.Module):
    """Module form of ``drop_path`` (stochastic depth applied per sample).

    Stores the drop probability and the rescaling flag, and forwards the
    module's own ``training`` state to ``drop_path`` on every call.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` using the stored settings."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability, rounded to three decimals, in ``repr``."""
        rounded = round(self.drop_prob, 3)
        return f"drop_prob={rounded:0.3f}"

class DiMBlockCombined(nn.Module):
    """Block that processes tokens in a spatial branch and a wavelet branch.

    The normalized input is split in two along the channel axis. One half is
    handled by a ``DiMBlockRaw`` and the other by a ``WaveDiMBlock``; the two
    results are merged by ``CrossAttentionFusion``, added back to the
    normalized input, and followed by a conditioned MLP.
    """

    def __init__(
        self,
        dim,
        mixer_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        drop_path=0.0,
        reverse=False,
        transpose=False,
        scanning_continuity=False,
        use_gated_mlp=True,
    ):
        """Create the two branches, the fusion layer and the MLP.

        Args:
            dim: channel width of the block; each branch is built for
                ``dim // 2`` channels.
            mixer_cls: forwarded to both branches, which call it with the
                branch width to create their sequence mixer.
            norm_cls: constructor for the two normalization layers of this
                block (the branches themselves use ``nn.Identity``).
            fused_add_norm: use the fused add+norm kernels; requires the
                optional ``RMSNorm`` import and a LayerNorm/RMSNorm ``norm``.
            residual_in_fp32: keep the residual stream in float32.
            drop_path: stochastic-depth probability applied to the incoming
                hidden states before they are added to the residual.
            reverse: forwarded to the spatial branch as ``reverse`` and to
                the wavelet branch as ``transpose``.
            transpose: forwarded to the spatial branch only.
            scanning_continuity: forwarded to both branches.
            use_gated_mlp: pick ``GatedMLP`` (default) or ``Mlp``.
        """
        super().__init__()
        half_dim = dim // 2
        self.fused_add_norm = fused_add_norm
        self.residual_in_fp32 = residual_in_fp32
        self.reverse = reverse
        self.transpose = transpose
        self.scanning_continuity = scanning_continuity

        self.norm = norm_cls(dim)

        # Settings shared by both branches: they do no normalization or
        # stochastic depth of their own and are conditioned on the full-width
        # vector.
        branch_kwargs = dict(
            norm_cls=nn.Identity,
            drop_path=0.0,
            fused_add_norm=False,
            residual_in_fp32=residual_in_fp32,
            scanning_continuity=scanning_continuity,
            c_dim=dim,
        )
        self.spatial_mamba = DiMBlockRaw(
            half_dim,
            mixer_cls,
            reverse=reverse,
            transpose=transpose,
            **branch_kwargs,
        )
        # The wavelet branch never flips its sequence; its `transpose` flag
        # is driven by `reverse` rather than by `transpose` (disable it if
        # only left-to-right scanning is used).
        self.freq_mamba = WaveDiMBlock(
            half_dim,
            mixer_cls,
            reverse=False,
            transpose=reverse,
            no_ffn=True,
            num_wavelet_lv=2,
            **branch_kwargs,
        )

        self.proj = CrossAttentionFusion(dim, num_heads=8, qkv_bias=True, swap_k=False)

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

        self.norm_2 = norm_cls(dim)

        # The conditioning vector is mapped to shift / scale / gate for the MLP.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim, bias=True))

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        mlp_cls = GatedMLP if use_gated_mlp else Mlp
        self.mlp = mlp_cls(
            in_features=dim,
            hidden_features=int(dim * 4),
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        c: Optional[Tensor] = None,
        inference_params=None,
    ):
        r"""Run pre-norm, both branches, fusion and the conditioned MLP.

        Args:
            hidden_states: input tokens; split along axis 2 into the two
                branch inputs.
            residual: running residual stream, or ``None`` on the first
                block, in which case ``hidden_states`` starts the stream.
            c: conditioning tensor fed to ``adaLN_modulation`` and to both
                branches. It is used unconditionally, so it must be given
                even though the signature defaults it to ``None``.
            inference_params: passed through to both branches.

        Returns:
            ``(hidden_states, residual)``: the block output and the updated
            residual stream.
        """
        if self.fused_add_norm:
            fused_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # Stochastic depth only applies when there is a residual to add to.
            fused_input = hidden_states if residual is None else self.drop_path(hidden_states)
            hidden_states, residual = fused_fn(
                fused_input,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        else:
            if residual is not None:
                residual = residual + self.drop_path(hidden_states)
            else:
                residual = hidden_states

            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        # First half of the channels -> spatial branch, second half -> wavelet branch.
        spatial_in, freq_in = hidden_states.chunk(2, dim=2)
        spatial_out, _ = self.spatial_mamba(spatial_in, None, c, inference_params)
        freq_out, _ = self.freq_mamba(freq_in, None, c, inference_params)

        if isinstance(self.proj, CrossAttentionFusion):
            fused = self.proj(spatial_out, freq_out)
        else:
            fused = self.proj(torch.cat((spatial_out, freq_out), dim=2))

        # FFN
        hidden_states = hidden_states + fused
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(3, dim=1)
        mlp_out = self.mlp(modulate(self.norm_2(hidden_states), shift_mlp, scale_mlp))
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * mlp_out

        return hidden_states, residual

class DiMBlockRaw(nn.Module):
    """Conditioned mixer block with optional re-ordering of the token sequence.

    After the pre-norm step the tokens can be transposed on their square
    grid, laid out in a back-and-forth (serpentine) order and/or reversed
    before the mixer runs. The same steps are undone in the opposite order
    afterwards.
    """

    def __init__(
        self,
        dim,
        mixer_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        drop_path=0.0,
        reverse=False,
        transpose=False,
        scanning_continuity=False,
        c_dim=None,
    ):
        """Create the mixer, the norm and the conditioning projection.

        Args:
            dim: channel width handled by the mixer and the norm.
            mixer_cls: called as ``mixer_cls(dim)`` to build the mixer.
            norm_cls: called as ``norm_cls(dim)``; ``nn.Identity`` is
                supported on the non-fused path.
            fused_add_norm: use the fused add+norm kernels; requires the
                optional ``RMSNorm`` import and a LayerNorm/RMSNorm ``norm``.
            residual_in_fp32: keep the residual stream in float32.
            drop_path: stochastic-depth probability for the incoming hidden
                states.
            reverse: flip the token sequence around the mixer.
            transpose: swap the two grid axes around the mixer.
            scanning_continuity: reverse every second grid line around the
                mixer.
            c_dim: width of the conditioning vector; defaults to ``dim``.
        """
        super().__init__()
        cond_dim = c_dim if c_dim is not None else dim
        self.fused_add_norm = fused_add_norm
        self.residual_in_fp32 = residual_in_fp32
        self.scanning_continuity = scanning_continuity
        self.transpose = transpose
        self.reverse = reverse

        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

        # The conditioning vector is mapped to shift / scale / gate for the mixer.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 3 * dim, bias=True))

    def _prenorm(self, hidden_states, residual):
        """Fold ``hidden_states`` into the residual stream and normalize it.

        Returns the normalized tensor together with the updated residual.
        """
        if self.fused_add_norm:
            fused_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # Stochastic depth only applies when there is a residual to add to.
            fused_input = hidden_states if residual is None else self.drop_path(hidden_states)
            return fused_fn(
                fused_input,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

        if residual is not None:
            residual = residual + self.drop_path(hidden_states)
        else:
            residual = hidden_states

        if isinstance(self.norm, nn.Identity):
            # Identity has no weight to take a dtype from.
            normed = self.norm(residual)
        else:
            normed = self.norm(residual.to(dtype=self.norm.weight.dtype))

        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        return normed, residual

    @staticmethod
    def _flip_odd_lines(tokens, side):
        """Reverse every second line of the ``side`` x ``side`` token grid.

        Works on a copy, so the caller's tensor is not modified. Applying
        the function twice restores the original order.
        """
        grid = rearrange(tokens.clone(), "n (w h) c -> n c w h", h=side, w=side)
        grid[:, :, 1::2] = grid[:, :, 1::2].flip(-1)
        return rearrange(grid, "n c w h -> n (w h) c", h=side, w=side)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        c: Optional[Tensor] = None,
        inference_params=None,
    ):
        r"""Normalize, re-order, mix and restore the token sequence.

        Args:
            hidden_states: tokens laid out as ``(n, tokens, c)``; the token
                count is treated as a square grid.
            residual: running residual stream, or ``None`` to start one from
                ``hidden_states``.
            c: conditioning tensor fed to ``adaLN_modulation`` and to the
                mixer. It is used unconditionally, so it must be given even
                though the signature defaults it to ``None``.
            inference_params: passed through to the mixer.

        Returns:
            ``(hidden_states, residual)``: the mixed tokens in their original
            order and the updated residual stream.
        """
        tokens, residual = self._prenorm(hidden_states, residual)

        side = int(np.sqrt(tokens.shape[1]))

        # Bring the tokens into scan order.
        if self.transpose:
            tokens = rearrange(tokens, "n (h w) c -> n (w h) c", h=side, w=side)
        if self.scanning_continuity:
            tokens = self._flip_odd_lines(tokens, side)
        if self.reverse:
            tokens = tokens.flip(1)

        shift_ssm, scale_ssm, gate_ssm = self.adaLN_modulation(c).chunk(3, dim=1)
        mixed = self.mixer(modulate(tokens, shift_ssm, scale_ssm), c, inference_params=inference_params)
        tokens = tokens + gate_ssm.unsqueeze(1) * mixed

        # Undo the re-ordering, last step first.
        if self.reverse:
            tokens = tokens.flip(1)
        if self.scanning_continuity:
            tokens = self._flip_odd_lines(tokens, side)
        if self.transpose:
            # On a square grid the axis swap is its own inverse.
            tokens = rearrange(tokens, "n (h w) c -> n (w h) c", h=side, w=side)

        return tokens, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        mixer = self.mixer
        return mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
    
class WaveDiMBlock(nn.Module):
    """Conditioned mixer block that runs on wavelet sub-bands of the tokens.

    After the pre-norm step the token grid is passed through a Haar wavelet
    transform (``_dwt_fast``), re-ordered for scanning, mixed, optionally
    followed by a conditioned MLP, and finally mapped back with the inverse
    re-ordering and the inverse transform (``_idwt_fast``).
    """

    # Decomposition depths that `_dwt_fast` / `_idwt_fast` implement: they
    # apply the transform at most twice and their sub-band permutation is
    # written for a 4x4 tiling.
    _SUPPORTED_WAVELET_LEVELS = (1, 2)

    def __init__(
        self,
        dim,
        mixer_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        drop_path=0.0,
        reverse=False,
        transpose=False,
        scanning_continuity=False,
        skip=False,
        no_ffn=False,
        c_dim=None,
        window_scan=True,
        num_wavelet_lv=2,
    ):
        """Create the mixer, the wavelet transforms and the optional MLP.

        Args:
            dim: channel width handled by the mixer and the norm.
            mixer_cls: called as ``mixer_cls(dim)`` to build the mixer.
            norm_cls: called as ``norm_cls(dim)``; ``nn.Identity`` is
                supported on the non-fused path.
            fused_add_norm: use the fused add+norm kernels; requires the
                optional ``RMSNorm`` import and a LayerNorm/RMSNorm ``norm``.
            residual_in_fp32: keep the residual stream in float32.
            drop_path: stochastic-depth probability for the incoming hidden
                states.
            reverse: flip the token sequence around the mixer.
            transpose: with ``window_scan`` it selects column-first window
                scanning; otherwise it swaps the two grid axes.
            scanning_continuity: reverse every second grid line around the
                mixer.
            skip: accepted for interface compatibility; not used here.
            no_ffn: when true the block has no MLP and the conditioning
                projection only produces shift / scale / gate for the mixer.
            c_dim: width of the conditioning vector; defaults to ``dim``.
            window_scan: re-order tokens with ``local_scan`` /
                ``local_reverse`` instead of a plain axis swap.
            num_wavelet_lv: number of wavelet levels, 1 or 2.

        Raises:
            ValueError: if ``num_wavelet_lv`` is not a supported level.
        """
        super().__init__()
        # Fail at construction instead of with an opaque shape error in the
        # first forward pass.
        if num_wavelet_lv not in self._SUPPORTED_WAVELET_LEVELS:
            raise ValueError(
                f"num_wavelet_lv must be one of {self._SUPPORTED_WAVELET_LEVELS}, got {num_wavelet_lv!r}"
            )

        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.reverse = reverse
        self.transpose = transpose
        self.scanning_continuity = scanning_continuity
        self.no_ffn = no_ffn
        self.window_scan = window_scan
        self.num_wavelet_lv = num_wavelet_lv
        c_dim = dim if c_dim is None else c_dim

        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

        self.dwt = DWT_2D(wave="haar")
        self.idwt = IDWT_2D(wave="haar")

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

        # Three modulation vectors for the mixer, plus three more for the MLP
        # when the block has one.
        num_modulations = 3 if self.no_ffn else 6
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(c_dim, num_modulations * dim, bias=True))
        if not self.no_ffn:
            self.norm_2 = norm_cls(dim)

            mlp_hidden_dim = int(dim * 4)
            approx_gelu = lambda: nn.GELU(approximate="tanh")
            self.mlp = GatedMLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)

    @staticmethod
    def _subband_order(patch_size):
        """Return the chunk permutation shared by ``_dwt_fast`` and ``_idwt_fast``.

        For ``patch_size == 4`` this transposes a 4x4 arrangement of chunks,
        so applying it twice restores the original order.
        """
        return [i % 4 * patch_size + i // 4 for i in range(patch_size * patch_size)]

    @staticmethod
    def _flip_odd_lines(tokens, h, w):
        """Reverse every second line of the token grid, working on a copy."""
        grid = rearrange(tokens.clone(), "n (w h) c -> n c w h", h=h, w=w)
        grid[:, :, 1::2] = grid[:, :, 1::2].flip(-1)
        return rearrange(grid, "n c w h -> n (w h) c", h=h, w=w)

    def _dwt_fast(self, x):
        """Map ``(b, tokens, c)`` tokens to wavelet coefficients, same layout.

        The tokens are viewed as a square grid, transformed once or twice
        depending on ``num_wavelet_lv``, divided by ``2 ** num_wavelet_lv``,
        and flattened back into a token sequence in which each coarse
        position expands to a ``patch_size`` x ``patch_size`` tile.
        """
        x = rearrange(x, "b (h w) c -> b c h w", h=int(np.sqrt(x.size(1))))
        # NOTE(review): the original comment describes the transform output as
        # the four sub-bands xll, xlh, xhl, xhh at half resolution; the code
        # below handles it as one tensor stacked on the channel axis - confirm
        # against DWT_2D.
        subbands = self.dwt(x)
        scale = 2**self.num_wavelet_lv
        patch_size = scale
        if self.num_wavelet_lv > 1:
            chunks = (self.dwt(subbands) / scale).chunk(patch_size * patch_size, dim=1)
            out = torch.cat([chunks[i] for i in self._subband_order(patch_size)], dim=1)
        else:
            out = subbands / scale
        return rearrange(out, "b (c p1 p2) h w -> b (h p1 w p2) c", p1=patch_size, p2=patch_size)  # [b, h*w, c]

    def _idwt_fast(self, x):
        """Undo ``_dwt_fast``: tile layout, scaling, permutation and transform."""
        scale = 2**self.num_wavelet_lv
        patch_size = scale
        lowest_size = int(np.sqrt(x.size(1))) // patch_size
        subbands = rearrange(
            x * scale, "b (h p1 w p2) c -> b (c p1 p2) h w", p1=patch_size, p2=patch_size, h=lowest_size
        ).chunk(patch_size * patch_size, dim=1)
        if self.num_wavelet_lv > 1:
            subbands = torch.cat([subbands[i] for i in self._subband_order(patch_size)], dim=1)
            out = self.idwt(subbands)
            out = self.idwt(out)
        else:
            out = self.idwt(torch.cat(subbands, dim=1))
        return rearrange(out, "b c h w -> b (h w) c")  # [b, h*w, c]

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        c: Optional[Tensor] = None,
        inference_params=None,
    ):
        r"""Normalize, transform, re-order, mix and restore the tokens.

        Args:
            hidden_states: tokens laid out as ``(n, tokens, c)``; the token
                count is treated as a square grid.
            residual: running residual stream, or ``None`` to start one from
                ``hidden_states``.
            c: conditioning tensor fed to ``adaLN_modulation`` and to the
                mixer. It is used unconditionally, so it must be given even
                though the signature defaults it to ``None``.
            inference_params: passed through to the mixer.

        Returns:
            ``(hidden_states, residual)``: the processed tokens and the
            updated residual stream.
        """
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)

            hidden_states = (
                self.norm(residual)
                if isinstance(self.norm, nn.Identity)
                else self.norm(residual.to(dtype=self.norm.weight.dtype))
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # Stochastic depth only applies when there is a residual to add to.
            fused_input = hidden_states if residual is None else self.drop_path(hidden_states)
            hidden_states, residual = fused_add_norm_fn(
                fused_input,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

        # The grid side is taken from the token count before the transform.
        h = w = int(np.sqrt(hidden_states.shape[1]))
        hidden_states = self._dwt_fast(hidden_states).contiguous()
        patch_size = int(2**self.num_wavelet_lv)  # old: 4
        column_first = bool(self.transpose)

        # Bring the coefficients into scan order.
        if self.window_scan:
            hidden_states = local_scan(
                hidden_states, w=w // patch_size, H=h, W=w, column_first=column_first
            ).contiguous()  # Resolve fixed window size
        elif self.transpose:
            hidden_states = rearrange(hidden_states, "n (h w) c -> n (w h) c", h=h, w=w)

        if self.scanning_continuity:
            hidden_states = self._flip_odd_lines(hidden_states, h, w)

        if self.reverse:
            hidden_states = hidden_states.flip(1)

        modulation = self.adaLN_modulation(c)
        if self.no_ffn:
            shift_ssm, scale_ssm, gate_ssm = modulation.chunk(3, dim=1)
        else:
            shift_ssm, scale_ssm, gate_ssm, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        hidden_states = hidden_states + gate_ssm.unsqueeze(1) * self.mixer(
            modulate(hidden_states, shift_ssm, scale_ssm), c, inference_params=inference_params
        )
        if not self.no_ffn:
            hidden_states = hidden_states + gate_mlp.unsqueeze(1) * self.mlp(
                modulate(self.norm_2(hidden_states), shift_mlp, scale_mlp)
            )

        # Undo the re-ordering, last step first.
        if self.reverse:
            hidden_states = hidden_states.flip(1)

        if self.scanning_continuity:
            hidden_states = self._flip_odd_lines(hidden_states, h, w)

        if self.window_scan:
            hidden_states = local_reverse(hidden_states, w=w // patch_size, H=h, W=w, column_first=column_first)
        elif self.transpose:
            # On a square grid the axis swap is its own inverse.
            hidden_states = rearrange(hidden_states, "n (h w) c -> n (w h) c", h=h, w=w)
        hidden_states = self._idwt_fast(hidden_states).contiguous()

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
    
def gen_paths(N, scan_type, depth=None):
    """Build the forward and inverse scan-path tensors for a stack of blocks.

    Args:
        N: forwarded unchanged to the path generator selected from
            ``SCAN_ZOO`` (presumably the grid side length - confirm against
            the generators).
        scan_type: string of the form ``"<name>_<count>"``. ``<name>``
            selects the generator in ``SCAN_ZOO`` and ``<count>`` is the
            maximum number of generated paths to keep.
        depth: number of times the kept paths are repeated along the first
            axis. Required in practice: the ``None`` default is kept only
            for signature compatibility.

    Returns:
        A dict with the keys ``"zigzag_paths"`` (the stacked paths),
        ``"zigzag_paths_reverse"`` (``reverse_permut_np`` applied to each
        path, stacked the same way) and ``"scan_type"`` (the input string).

    Raises:
        TypeError: if ``depth`` is ``None``.
    """
    if depth is None:
        # Without this check the list repetition below fails with an
        # unhelpful "can't multiply sequence by non-int" error.
        raise TypeError("gen_paths() requires an integer `depth`; got None")

    name_parts = scan_type.split("_")
    path_type = name_parts[0]
    num_paths = int(name_parts[1])

    paths = SCAN_ZOO[path_type](N)[:num_paths]
    inverse_paths = [reverse_permut_np(path) for path in paths]

    # Give every path a leading axis of size 1, repeat the list `depth`
    # times, then stack along that axis.
    zz_paths = torch.cat([torch.from_numpy(path)[None,] for path in paths] * depth, dim=0)
    zz_paths_rev = torch.cat([torch.from_numpy(path)[None,] for path in inverse_paths] * depth, dim=0)

    assert len(zz_paths) == len(zz_paths_rev), f"{len(zz_paths)} != {len(zz_paths_rev)}"

    return {
        "zigzag_paths": zz_paths,
        "zigzag_paths_reverse": zz_paths_rev,
        "scan_type": scan_type,
    }